// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>

#include <iostream>
#include <limits>

#include "paddle/extension.h"
#include "paddle/phi/api/all.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"

// 简单的类型提取
template <paddle::DataType DType>
struct TypeMap;
template <>
struct TypeMap<paddle::DataType::BFLOAT16> {
  using type = phi::bfloat16;
};
template <>
struct TypeMap<paddle::DataType::FLOAT16> {
  using type = phi::float16;
};
template <>
struct TypeMap<paddle::DataType::FLOAT32> {
  using type = float;
};
template <>
struct TypeMap<paddle::DataType::INT32> {
  using type = int;
};
template <>
struct TypeMap<paddle::DataType::INT64> {
  using type = int64_t;
};


// cuda结构体拷贝
template <typename T, int N>
struct alignas(16) VectorType {
  T data[N];
};

// 128Byte对齐的结构体
template <>
struct alignas(16) VectorType<float, 4> {
  float4 data;  // Built-in CUDA vector type
};

template <>
struct alignas(16) VectorType<__nv_bfloat16, 8> {
  __nv_bfloat16 data[8];
};

template <>
struct alignas(16) VectorType<__nv_fp8_e4m3, 16> {
  __nv_fp8_e4m3 data[16];
};

template <>
struct alignas(16) VectorType<uint8_t, 16> {
  uint8_t data[16];
};

// Helper function to perform vectorized memory copy
template <typename T>
__device__ __forceinline__ void vectorized_memcpy(const T* src,
                                                  T* dst,
                                                  int num_elements) {
  constexpr int vector_size_in_bytes = 16;
  const int elements_per_vector = vector_size_in_bytes / sizeof(T);

  // 已知单行token向量化不会超过4G大小，用int节省整数开销
  int num_vectors = num_elements / elements_per_vector;
  int remaining_elements = num_elements % elements_per_vector;

  using VecType = VectorType<T, elements_per_vector>;
  const VecType* src_vec = reinterpret_cast<const VecType*>(src);
  VecType* dst_vec = reinterpret_cast<VecType*>(dst);

// 已知paddle框架中的显存分配均为256Bytes对齐，所以默认align
#pragma unroll
  for (int idx = threadIdx.x; idx < num_vectors; idx += blockDim.x) {
    dst_vec[idx] = src_vec[idx];
  }

  // 剩余无法向量化处理的元素
  if (remaining_elements > 0) {
    int offset = num_vectors * elements_per_vector;
    for (int i = threadIdx.x; i < remaining_elements; i += blockDim.x) {
      dst[offset + i] = src[offset + i];
    }
  }
}

#define PD_SWITCH_NUM_EXPERTS_IMPL(__num_expert, __max_num_experts, ...) \
  if (__num_expert <= __max_num_experts) {                               \
    constexpr auto MAX_NUM_EXPERTS_C = __max_num_experts;                \
    do {                                                                 \
      __VA_ARGS__();                                                     \
    } while (0);                                                         \
    break;                                                               \
  }


#define PD_SWITCH_NUM_EXPERTS(__num_experts_expr, ...)           \
  do {                                                           \
    auto __num_expert = (__num_experts_expr);                    \
    PD_SWITCH_NUM_EXPERTS_IMPL(__num_expert, 8, __VA_ARGS__);    \
    PD_SWITCH_NUM_EXPERTS_IMPL(__num_expert, 16, __VA_ARGS__);   \
    PD_SWITCH_NUM_EXPERTS_IMPL(__num_expert, 32, __VA_ARGS__);   \
    PD_SWITCH_NUM_EXPERTS_IMPL(__num_expert, 64, __VA_ARGS__);   \
    PD_THROW("Unsupported expert number %d", int(__num_expert)); \
  } while (0)
